{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "HW08.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "bDk9r2YOcDc9"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YiVfKn-6tXz8"
      },
      "source": [
        "# **Homework 8 - Anomaly Detection**\n",
        "\n",
        "If there are any questions, please contact ntu-ml-2021spring-ta@googlegroups.com"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w6s3wqonWoxH"
      },
      "source": [
        "# Mounting your gdrive (Optional)\n",
        "By mounting your gdrive, you can save and manage your data and models in your Google drive"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dfqPBGxoWpkC",
        "outputId": "cca8d557-ebfa-4827-e581-eeb73a4e882c"
      },
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/gdrive')\n",
        "import os\n",
        "\n",
        "# your workspace in your drive\n",
        "workspace = 'YOUR_WORKSPACE'\n",
        "\n",
        "\n",
        "try:\n",
        "  os.chdir(os.path.join('/content/gdrive/My Drive/', workspace))\n",
        "except:\n",
        "  os.mkdir(os.path.join('/content/gdrive/My Drive/', workspace))\n",
        "  os.chdir(os.path.join('/content/gdrive/My Drive/', workspace))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Mounted at /content/gdrive\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bDk9r2YOcDc9"
      },
      "source": [
        "# Set up the environment\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Oi12tJMYWi0Q"
      },
      "source": [
        "## Package installation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7LexxyPWWjJB",
        "outputId": "e8d053f0-1587-4721-9e27-0adf60f7a280"
      },
      "source": [
        "# Training progress bar\n",
        "!pip install -q qqdm"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "  Building wheel for qqdm (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DCgNXSsEWuY7"
      },
      "source": [
        "## Downloading data\n",
        "**Please use the link according to the last digit of your student ID first!**\n",
        "\n",
        "If all download links fail, please follow [here](https://drive.google.com/drive/folders/13T0Pa_WGgQxNkqZk781qhc5T9-zfh19e).\n",
        "\n",
        "* To open the file using your browser, use the link below (replace the id first!):\n",
        "https://drive.google.com/file/d/REPLACE_WITH_ID\n",
        "* e.g. https://drive.google.com/file/d/15XWO-zI-AKW0igfwSydmwSGa8ENb9wCg"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "r_wwI4CSWvc2",
        "outputId": "0d5952d0-826e-4a73-a612-fdec5707009a"
      },
      "source": [
        "\n",
        "!gdown --id '15XWO-zI-AKW0igfwSydmwSGa8ENb9wCg' --output data-bin.tar.gz \n",
        "\n",
        "# Other download links\n",
        "#   Please uncomment the line according to the last digit of your student ID first\n",
        "\n",
        "# 0\n",
        "# !gdown --id '167SejKP7vLB2sbHfQHJii8-WisYoTmLH' --output data-bin.tar.gz \n",
        "\n",
        "# 1\n",
        "# !gdown --id '1BXJaeouaf4Zml2aeNlQfJ_AOcItTWcef' --output data-bin.tar.gz \n",
        "\n",
        "# 2\n",
        "# !gdown --id '1HkBPxhk-9rD0H_cen2YjLXxsvInkToBl' --output data-bin.tar.gz \n",
        "\n",
        "# 3\n",
        "# !gdown --id '1K_WGT8AD8iMsOSMYtK1Gp6vyEcRNCLQM' --output data-bin.tar.gz \n",
        "\n",
        "# 4\n",
        "# !gdown --id '1LGdyDUQA4EPaWTEUVm_upPAEl6qAh91Z' --output data-bin.tar.gz \n",
        "\n",
        "# 5\n",
        "# !gdown --id '1N9wNazaMy4A0UQ6pow5DXfVJ6abaiQxU' --output data-bin.tar.gz \n",
        "\n",
        "# 6\n",
        "# !gdown --id '1PC66MrDw-tnuYN2STauPg2FoJYm3_Yy5' --output data-bin.tar.gz \n",
        "\n",
        "# 7\n",
        "# !gdown --id '1mzy4E06CcBJc0udhPgL4zMhDlWibKbVs' --output data-bin.tar.gz \n",
        "\n",
        "# 8\n",
        "# !gdown --id '1zPbCF7whPv1Xs_2azwe1SUweomgLsVwH' --output data-bin.tar.gz \n",
        "\n",
        "# 9\n",
        "# !gdown --id '1Uc1Y8YYAwj7D0_wd0MeSX3szUiIB1rLU' --output data-bin.tar.gz "
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=15XWO-zI-AKW0igfwSydmwSGa8ENb9wCg\n",
            "To: /content/gdrive/My Drive/phucque/data-bin.tar.gz\n",
            "1.64GB [00:12, 129MB/s]\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1mlKBtfSenGu96ny9YDb10l_cEu-DG-gs\n",
            "To: /content/gdrive/My Drive/phucque/testing_labels.npy\n",
            "100% 160k/160k [00:00<00:00, 23.3MB/s]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1hui5EYoWzB4"
      },
      "source": [
        "## Untar data\n",
        "\n",
        "data-bin contains 2 files\n",
        "```\n",
        "data-bin/\n",
        "├── trainingset.npy\n",
        "├── testingset.npy\n",
        "...\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0K5kmlkuWzhJ",
        "outputId": "44d9c8e1-cda9-40ed-d85e-9aef6fc9d825"
      },
      "source": [
        "!tar zxvf data-bin.tar.gz\n",
        "!ls data-bin\n",
        "!ls data-bin\n",
        "!rm data-bin.tar.gz"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "data-bin/\n",
            "data-bin/testingset.npy\n",
            "data-bin/trainingset.npy\n",
            "testing_labels.npy  testingset.npy  trainingset.npy\n",
            "mv: cannot stat 'testing_labels.npy': No such file or directory\n",
            "testing_labels.npy  testingset.npy  trainingset.npy\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HNe7QU7n7cqh"
      },
      "source": [
        "# Import packages"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Jk3qFK_a7k8P"
      },
      "source": [
        "import numpy as np\n",
        "import random\n",
        "import torch\n",
        "\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,\n",
        "                              TensorDataset)\n",
        "import torchvision.transforms as transforms\n",
        "\n",
        "from torch import nn\n",
        "import torch.nn.functional as F\n",
        "from torch.autograd import Variable\n",
        "import torchvision.models as models\n",
        "\n",
        "from torch.optim import Adam, AdamW\n",
        "\n",
        "from sklearn.cluster import MiniBatchKMeans\n",
        "from scipy.cluster.vq import vq, kmeans\n",
        "\n",
        "from qqdm import qqdm, format_str\n",
        "import pandas as pd\n",
        "\n",
        "import pdb  # use pdb.set_trace() to set breakpoints for debugging\n",
        "\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6X6fkGPnYyaF"
      },
      "source": [
        "# Loading data"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "k7Wd4yiUYzAm",
        "outputId": "28aa0306-c4fc-4ba1-ec75-9f611b1b3304"
      },
      "source": [
        "\n",
        "train = np.load('data-bin/trainingset.npy', allow_pickle=True)\n",
        "test = np.load('data-bin/testingset.npy', allow_pickle=True)\n",
        "\n",
        "print(train.shape)\n",
        "print(test.shape)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "(140001, 64, 64, 3)\n",
            "(19999, 64, 64, 3)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_flpmj6OYIa6"
      },
      "source": [
        "## Random seed\n",
        "Set the random seed to a certain value for reproducibility."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Gb-dgXQYYI2Q"
      },
      "source": [
        "def same_seeds(seed):\n",
        "    # Python built-in random module\n",
        "    random.seed(seed)\n",
        "    # Numpy\n",
        "    np.random.seed(seed)\n",
        "    # Torch\n",
        "    torch.manual_seed(seed)\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed(seed)\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "\n",
        "same_seeds(19530615)"
      ],
      "execution_count": null,
      "outputs": []
    },
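    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To illustrate what seeding buys you, here is a minimal sketch using only Python's built-in `random` module (one of the generators that `same_seeds` seeds). The helper `draw` is hypothetical and not part of the homework code; the same idea applies to NumPy and PyTorch, although GPU kernels may still introduce small nondeterminism.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "def draw(seed, n=3):\n",
        "    # Re-seed the built-in generator, then draw n pseudo-random numbers.\n",
        "    random.seed(seed)\n",
        "    return [random.random() for _ in range(n)]\n",
        "\n",
        "first = draw(19530615)\n",
        "second = draw(19530615)\n",
        "# Re-seeding with the same value reproduces the same sequence.\n",
        "print(first == second)\n",
        "```\n",
        "\n",
        "Calling `same_seeds` once before building the model and dataloader should therefore make weight initialization and batch order repeatable across runs."
      ]
    },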
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zR9zC0_Df-CR"
      },
      "source": [
        "# Autoencoder"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1EbfwRREhA7c"
      },
      "source": [
        "# Models & loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xUUIwFm02qEL"
      },
      "source": [
        "Lecture video：https://www.youtube.com/watch?v=6W8FqUGYyDo&list=PLJV_el3uVTsOK_ZK5L0Iv_EQoL1JefRL4&index=8"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lESPM7C9svYa"
      },
      "source": [
        "fcn_autoencoder and vae are from https://github.com/L1aoXingyu/pytorch-beginner"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GwG2XzgWs_JR"
      },
      "source": [
        "conv_autoencoder is from https://github.com/jellycsc/PyTorch-CIFAR-10-autoencoder/"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Wi8ds1fugCkR"
      },
      "source": [
        "class fcn_autoencoder(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(fcn_autoencoder, self).__init__()\n",
        "        self.encoder = nn.Sequential(\n",
        "            nn.Linear(64 * 64 * 3, 128),\n",
        "            nn.ReLU(True),\n",
        "            nn.Linear(128, 64),\n",
        "            nn.ReLU(True), \n",
        "            nn.Linear(64, 12), \n",
        "            nn.ReLU(True), \n",
        "            nn.Linear(12, 3))\n",
        "        \n",
        "        self.decoder = nn.Sequential(\n",
        "            nn.Linear(3, 12),\n",
        "            nn.ReLU(True),\n",
        "            nn.Linear(12, 64),\n",
        "            nn.ReLU(True),\n",
        "            nn.Linear(64, 128),\n",
        "            nn.ReLU(True), \n",
        "            nn.Linear(128, 64 * 64 * 3), \n",
        "            nn.Tanh())\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.encoder(x)\n",
        "        x = self.decoder(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "# maybe it can be smaller\n",
        "class conv_autoencoder(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(conv_autoencoder, self).__init__()\n",
        "        self.encoder = nn.Sequential(\n",
        "            nn.Conv2d(3, 12, 4, stride=2, padding=1),         \n",
        "            nn.ReLU(),\n",
        "            nn.Conv2d(12, 24, 4, stride=2, padding=1),        \n",
        "            nn.ReLU(),\n",
        "\t\t\t      nn.Conv2d(24, 48, 4, stride=2, padding=1),         \n",
        "            nn.ReLU(),\n",
        "            nn.Conv2d(48, 96, 4, stride=2, padding=1),   # medium: remove this layer\n",
        "            nn.ReLU(),\n",
        "        )\n",
        "        self.decoder = nn.Sequential(\n",
        "            nn.ConvTranspose2d(96, 48, 4, stride=2, padding=1), # medium: remove this layer\n",
        "            nn.ReLU(),\n",
        "\t\t\t      nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1), \n",
        "            nn.ReLU(),\n",
        "\t\t\t      nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1), \n",
        "            nn.ReLU(),\n",
        "            nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1),\n",
        "            nn.Tanh(),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.encoder(x)\n",
        "        x = self.decoder(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "class VAE(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(VAE, self).__init__()\n",
        "        self.encoder = nn.Sequential(\n",
        "            nn.Conv2d(3, 12, 4, stride=2, padding=1),            \n",
        "            nn.ReLU(),\n",
        "            nn.Conv2d(12, 24, 4, stride=2, padding=1),    \n",
        "            nn.ReLU(),\n",
        "        )\n",
        "        self.enc_out_1 = nn.Sequential(\n",
        "            nn.Conv2d(24, 48, 4, stride=2, padding=1),  \n",
        "            nn.ReLU(),\n",
        "        )\n",
        "        self.enc_out_2 = nn.Sequential(\n",
        "            nn.Conv2d(24, 48, 4, stride=2, padding=1),\n",
        "            nn.ReLU(),\n",
        "        )\n",
        "        self.decoder = nn.Sequential(\n",
        "\t\t\t      nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1), \n",
        "            nn.ReLU(),\n",
        "\t\t\t      nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1), \n",
        "            nn.ReLU(),\n",
        "            nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1), \n",
        "            nn.Tanh(),\n",
        "        )\n",
        "\n",
        "    def encode(self, x):\n",
        "        h1 = self.encoder(x)\n",
        "        return self.enc_out_1(h1), self.enc_out_2(h1)\n",
        "\n",
        "    def reparametrize(self, mu, logvar):\n",
        "        std = logvar.mul(0.5).exp_()\n",
        "        if torch.cuda.is_available():\n",
        "            eps = torch.cuda.FloatTensor(std.size()).normal_()\n",
        "        else:\n",
        "            eps = torch.FloatTensor(std.size()).normal_()\n",
        "        eps = Variable(eps)\n",
        "        return eps.mul(std).add_(mu)\n",
        "\n",
        "    def decode(self, z):\n",
        "        return self.decoder(z)\n",
        "\n",
        "    def forward(self, x):\n",
        "        mu, logvar = self.encode(x)\n",
        "        z = self.reparametrize(mu, logvar)\n",
        "        return self.decode(z), mu, logvar\n",
        "\n",
        "\n",
        "def loss_vae(recon_x, x, mu, logvar, criterion):\n",
        "    \"\"\"\n",
        "    recon_x: generating images\n",
        "    x: origin images\n",
        "    mu: latent mean\n",
        "    logvar: latent log variance\n",
        "    \"\"\"\n",
        "    mse = criterion(recon_x, x)  # mse loss\n",
        "    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)\n",
        "    KLD = torch.sum(KLD_element).mul_(-0.5)\n",
        "    return mse + KLD\n",
        "\n",
        "\n",
        "class Resnet(nn.Module):\n",
        "    def __init__(self, fc_hidden1=1024, fc_hidden2=768, drop_p=0.3, CNN_embed_dim=256):\n",
        "        super(Resnet, self).__init__()\n",
        "\n",
        "        self.fc_hidden1, self.fc_hidden2, self.CNN_embed_dim = fc_hidden1, fc_hidden2, CNN_embed_dim\n",
        "\n",
        "        # CNN architechtures\n",
        "        self.ch1, self.ch2, self.ch3, self.ch4 = 16, 32, 64, 128\n",
        "        self.k1, self.k2, self.k3, self.k4 = (5, 5), (3, 3), (3, 3), (3, 3)      # 2d kernal size\n",
        "        self.s1, self.s2, self.s3, self.s4 = (2, 2), (2, 2), (2, 2), (2, 2)      # 2d strides\n",
        "        self.pd1, self.pd2, self.pd3, self.pd4 = (0, 0), (0, 0), (0, 0), (0, 0)  # 2d padding\n",
        "\n",
        "        # encoding components\n",
        "        resnet = models.resnet18(pretrained=False)\n",
        "        modules = list(resnet.children())[:-1]      # delete the last fc layer.\n",
        "        self.resnet = nn.Sequential(*modules)\n",
        "        self.fc1 = nn.Linear(resnet.fc.in_features, self.fc_hidden1)\n",
        "        self.bn1 = nn.BatchNorm1d(self.fc_hidden1, momentum=0.01)\n",
        "        self.fc2 = nn.Linear(self.fc_hidden1, self.fc_hidden2)\n",
        "        self.bn2 = nn.BatchNorm1d(self.fc_hidden2, momentum=0.01)\n",
        "\n",
        "        self.fc3_mu = nn.Linear(self.fc_hidden2, self.CNN_embed_dim)      # output = CNN embedding latent variables\n",
        "\n",
        "        # Sampling vector\n",
        "        self.fc4 = nn.Linear(self.CNN_embed_dim, self.fc_hidden2)\n",
        "        self.fc_bn4 = nn.BatchNorm1d(self.fc_hidden2)\n",
        "        self.fc5 = nn.Linear(self.fc_hidden2, 64 * 4 * 4)\n",
        "        self.fc_bn5 = nn.BatchNorm1d(64 * 4 * 4)\n",
        "        self.relu = nn.ReLU(inplace=True)\n",
        "\n",
        "        # Decoder\n",
        "        self.convTrans6 = nn.Sequential(\n",
        "            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=self.k4, stride=self.s4,\n",
        "                               padding=self.pd4),\n",
        "            nn.BatchNorm2d(32, momentum=0.01),\n",
        "            nn.ReLU(inplace=True),\n",
        "        )\n",
        "        self.convTrans7 = nn.Sequential(\n",
        "            nn.ConvTranspose2d(in_channels=32, out_channels=8, kernel_size=self.k3, stride=self.s3,\n",
        "                               padding=self.pd3),\n",
        "            nn.BatchNorm2d(8, momentum=0.01),\n",
        "            nn.ReLU(inplace=True),\n",
        "        )\n",
        "\n",
        "        self.convTrans8 = nn.Sequential(\n",
        "            nn.ConvTranspose2d(in_channels=8, out_channels=3, kernel_size=self.k2, stride=self.s2,\n",
        "                               padding=self.pd2),\n",
        "            nn.BatchNorm2d(3, momentum=0.01),\n",
        "            nn.Sigmoid()    # y = (y1, y2, y3) \\in [0 ,1]^3\n",
        "        )\n",
        "\n",
        "\n",
        "    def encode(self, x):\n",
        "        x = self.resnet(x)  # ResNet\n",
        "        x = x.view(x.size(0), -1)  # flatten output of conv\n",
        "\n",
        "        # FC layers\n",
        "        if x.shape[0] > 1:\n",
        "            x = self.bn1(self.fc1(x))\n",
        "        else:\n",
        "            x = self.fc1(x)\n",
        "        x = self.relu(x)\n",
        "        if x.shape[0] > 1:\n",
        "            x = self.bn2(self.fc2(x))\n",
        "        else:\n",
        "            x = self.fc2(x)\n",
        "        x = self.relu(x)\n",
        "        x = self.fc3_mu(x)\n",
        "        return x\n",
        "\n",
        "    def decode(self, z):\n",
        "        if z.shape[0] > 1:\n",
        "            x = self.relu(self.fc_bn4(self.fc4(z)))\n",
        "            x = self.relu(self.fc_bn5(self.fc5(x))).view(-1, 64, 4, 4)\n",
        "        else:\n",
        "            x = self.relu(self.fc4(z))\n",
        "            x = self.relu(self.fc5(x)).view(-1, 64, 4, 4)\n",
        "        x = self.convTrans6(x)\n",
        "        x = self.convTrans7(x)\n",
        "        x = self.convTrans8(x)\n",
        "        x = F.interpolate(x, size=(64, 64), mode='bilinear', align_corners=True)\n",
        "        return x\n",
        "\n",
        "    def forward(self, x):\n",
        "        z = self.encode(x)\n",
        "        x_reconst = self.decode(z)\n",
        "\n",
        "        return x_reconst\n"
      ],
      "execution_count": null,
      "outputs": []
    },
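    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sanity check on `conv_autoencoder`, the sketch below traces the spatial size of a 64 x 64 input through four stride-2 convolutions and back through four transposed convolutions. It uses the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d`, assuming dilation 1 and no output padding. The helper names `conv_out`, `conv_transpose_out` and `trace` are hypothetical and not part of the homework code.\n",
        "\n",
        "```python\n",
        "def conv_out(n, kernel=4, stride=2, padding=1):\n",
        "    # Output size of nn.Conv2d along one spatial dimension (dilation 1).\n",
        "    return (n + 2 * padding - kernel) // stride + 1\n",
        "\n",
        "def conv_transpose_out(n, kernel=4, stride=2, padding=1):\n",
        "    # Output size of nn.ConvTranspose2d (output_padding 0, dilation 1).\n",
        "    return (n - 1) * stride - 2 * padding + kernel\n",
        "\n",
        "def trace(n, layer_fn, num_layers):\n",
        "    # Apply the same layer geometry num_layers times and record every size.\n",
        "    sizes = [n]\n",
        "    for _ in range(num_layers):\n",
        "        n = layer_fn(n)\n",
        "        sizes.append(n)\n",
        "    return sizes\n",
        "\n",
        "encoder_sizes = trace(64, conv_out, 4)  # four stride-2 convolutions\n",
        "decoder_sizes = trace(encoder_sizes[-1], conv_transpose_out, 4)\n",
        "print(encoder_sizes)\n",
        "print(decoder_sizes)\n",
        "```\n",
        "\n",
        "Under these assumptions each layer halves (or doubles) the side length, so the encoder ends at a 96 x 4 x 4 latent and the decoder returns to 64 x 64, the same shape as the input."
      ]
    },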
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vrJ9bScg9AgO"
      },
      "source": [
        "# Dataset module\n",
        "\n",
        "Module for obtaining and processing data. The transform function here normalizes image's pixels from [0, 255] to [-1.0, 1.0].\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "33fWhE-h9LPq"
      },
      "source": [
        "class CustomTensorDataset(TensorDataset):\n",
        "    \"\"\"TensorDataset with support of transforms.\n",
        "    \"\"\"\n",
        "    def __init__(self, tensors):\n",
        "        self.tensors = tensors\n",
        "        if tensors.shape[-1] == 3:\n",
        "            self.tensors = tensors.permute(0, 3, 1, 2)\n",
        "        \n",
        "        self.transform = transforms.Compose([\n",
        "                            transforms.Lambda(lambda x: x.to(torch.float32)),\n",
        "                            transforms.Lambda(lambda x: 2. * x/255. - 1.),\n",
        "                            # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),\n",
        "                            ])\n",
        "        \n",
        "    def __getitem__(self, index):\n",
        "        x = self.tensors[index]\n",
        "        \n",
        "        if self.transform:\n",
        "            # mapping images to [-1.0, 1.0]\n",
        "            x = self.transform(x)\n",
        "\n",
        "        return x\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.tensors)"
      ],
      "execution_count": null,
      "outputs": []
    },
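    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The rescaling in `CustomTensorDataset` maps a pixel value x to 2 * x / 255 - 1. The plain-Python sketch below shows that arithmetic on single values, together with an inverse that may be handy for viewing reconstructions. The function names `to_unit_range` and `to_pixel` are hypothetical; the notebook itself applies the same formula to whole tensors.\n",
        "\n",
        "```python\n",
        "def to_unit_range(pixel):\n",
        "    # Same arithmetic as the dataset transform: [0, 255] -> [-1, 1].\n",
        "    return 2.0 * pixel / 255.0 - 1.0\n",
        "\n",
        "def to_pixel(value):\n",
        "    # Inverse mapping: [-1, 1] -> [0, 255].\n",
        "    return (value + 1.0) * 255.0 / 2.0\n",
        "\n",
        "low, high = to_unit_range(0), to_unit_range(255)\n",
        "print(low, high)\n",
        "print(to_pixel(to_unit_range(100)))\n",
        "```\n",
        "\n",
        "The endpoints 0 and 255 land on -1.0 and 1.0, so a model whose last layer is `Tanh` can in principle reach every target value."
      ]
    },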
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XKNUImqUhIeq"
      },
      "source": [
        "# Training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7ebAJdjFmS08"
      },
      "source": [
        "## Initialize\n",
        "- hyperparameters\n",
        "- dataloader\n",
        "- model\n",
        "- optimizer & loss\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "in7yLfmqtZTk"
      },
      "source": [
        "# Training hyperparameters\n",
        "num_epochs = 50\n",
        "batch_size = 10000 # medium: smaller batchsize\n",
        "learning_rate = 1e-3\n",
        "\n",
        "# Build training dataloader\n",
        "x = torch.from_numpy(train)\n",
        "train_dataset = CustomTensorDataset(x)\n",
        "\n",
        "train_sampler = RandomSampler(train_dataset)\n",
        "train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)\n",
        "\n",
        "# Model\n",
        "model_type = 'cnn'   # selecting a model type from {'cnn', 'fcn', 'vae', 'resnet'}\n",
        "model_classes = {'resnet': Resnet(), 'fcn':fcn_autoencoder(), 'cnn':conv_autoencoder(), 'vae':VAE(), }\n",
        "model = model_classes[model_type].cuda()\n",
        "\n",
        "# Loss and optimizer\n",
        "criterion = nn.MSELoss()\n",
        "optimizer = torch.optim.Adam(\n",
        "    model.parameters(), lr=learning_rate)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wyooN-JPm8sS"
      },
      "source": [
        "## Training loop"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JoW1UrrxgI_U"
      },
      "source": [
        "\n",
        "best_loss = np.inf\n",
        "model.train()\n",
        "\n",
        "qqdm_train = qqdm(range(num_epochs), desc=format_str('bold', 'Description'))\n",
        "for epoch in qqdm_train:\n",
        "    tot_loss = list()\n",
        "    for data in train_dataloader:\n",
        "\n",
        "        # ===================loading=====================\n",
        "        if model_type in ['cnn', 'vae', 'resnet']:\n",
        "            img = data.float().cuda()\n",
        "        elif model_type in ['fcn']:\n",
        "            img = data.float().cuda()\n",
        "            img = img.view(img.shape[0], -1)\n",
        "\n",
        "        # ===================forward=====================\n",
        "        output = model(img)\n",
        "        if model_type in ['vae']:\n",
        "            loss = loss_vae(output[0], img, output[1], output[2], criterion)\n",
        "        else:\n",
        "            loss = criterion(output, img)\n",
        "\n",
        "        tot_loss.append(loss.item())\n",
        "        # ===================backward====================\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "    # ===================save_best====================\n",
        "    mean_loss = np.mean(tot_loss)\n",
        "    if mean_loss < best_loss:\n",
        "        best_loss = mean_loss\n",
        "        torch.save(model, 'best_model_{}.pt'.format(model_type))\n",
        "    # ===================log========================\n",
        "    qqdm_train.set_infos({\n",
        "      'epoch': f'{epoch + 1:.0f}/{num_epochs:.0f}',\n",
        "      'loss': f'{mean_loss:.4f}',\n",
        "    })\n",
        "    # ===================save_last========================\n",
        "    torch.save(model, 'last_model_{}.pt'.format(model_type))\n",
        "\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wk0UxFuchLzR"
      },
      "source": [
        "# Inference\n",
        "Model is loaded and generates its anomaly score predictions."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "evgMW3OwoGqD"
      },
      "source": [
        "## Initialize\n",
        "- dataloader\n",
        "- model\n",
        "- prediction file"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_MBnXAswoKmq"
      },
      "source": [
        "eval_batch_size = 200\n",
        "\n",
        "# build testing dataloader\n",
        "data = torch.tensor(test, dtype=torch.float32)\n",
        "test_dataset = CustomTensorDataset(data)\n",
        "test_sampler = SequentialSampler(test_dataset)\n",
        "test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=eval_batch_size, num_workers=1)\n",
        "eval_loss = nn.MSELoss(reduction='none')\n",
        "\n",
        "# load trained model\n",
        "checkpoint_path = 'CHECKPOINT.pt'\n",
        "model = torch.load(checkpoint_path)\n",
        "model.eval()\n",
        "\n",
        "# prediction file \n",
        "out_file = 'PREDICTION_FILE.csv'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r1PS_ApzhfOQ"
      },
      "source": [
        "    \n",
        "anomality = list()\n",
        "with torch.no_grad():\n",
        "    for i, data in enumerate(test_dataloader): \n",
        "        if model_type in ['cnn', 'vae', 'resnet']:\n",
        "            img = data.float().cuda()\n",
        "        elif model_type in ['fcn']:\n",
        "            img = data.float().cuda()\n",
        "            img = img.view(img.shape[0], -1)\n",
        "        else:\n",
        "            img = data[0].cuda()\n",
        "        output = model(img)\n",
        "        if model_type in ['cnn', 'resnet', 'fcn']:\n",
        "            output = output\n",
        "        elif model_type in ['res_vae']:\n",
        "            output = output[0]\n",
        "        elif model_type in ['vae']: # , 'vqvae'\n",
        "            output = output[0]\n",
        "        if model_type in ['fcn']:\n",
        "            loss = eval_loss(output, img).sum(-1)\n",
        "        else:\n",
        "            loss = eval_loss(output, img).sum([1, 2, 3])\n",
        "        anomality.append(loss)\n",
        "anomality = torch.cat(anomality, axis=0)\n",
        "anomality = torch.sqrt(anomality).reshape(len(test), 1).cpu().numpy()\n",
        "\n",
        "df = pd.DataFrame(anomality, columns=['Predicted'])\n",
        "df.to_csv(out_file, index_label = 'Id')\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": []
    },
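    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The score written to the prediction file is the square root of the summed squared reconstruction error of each image. The sketch below reproduces that quantity in plain Python on a toy, already flattened image. The function name `anomaly_score` and the toy values are assumptions made for illustration only.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def anomaly_score(original, reconstruction):\n",
        "    # Square root of the summed squared error over all values of one image.\n",
        "    return math.sqrt(sum((o - r) ** 2 for o, r in zip(original, reconstruction)))\n",
        "\n",
        "image = [0.0, 0.5, -0.5, 1.0]       # toy flattened image in [-1, 1]\n",
        "good_recon = [0.0, 0.5, -0.5, 1.0]  # perfect reconstruction\n",
        "bad_recon = [1.0, 0.5, -0.5, 1.0]   # one value is off by 1\n",
        "\n",
        "good_score = anomaly_score(image, good_recon)\n",
        "bad_score = anomaly_score(image, bad_recon)\n",
        "print(good_score, bad_score)\n",
        "```\n",
        "\n",
        "An autoencoder trained only on normal images is expected to reconstruct them well, so a larger score suggests that an image is more likely to be anomalous."
      ]
    },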
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ATBy-JhH7TUt"
      },
      "source": [
        "# Training statistics\n",
        "- Number of parameters\n",
        "- Training time on colab\n",
        "- Training curve of the bossbaseline model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HxVv7IYfuTHE"
      },
      "source": [
        "- Simple\n",
        " - Number of parameters: 3176419\n",
        " - Training time on colab: ~ 30 min\n",
        "- Medium\n",
        " - Number of parameters: 47355\n",
        " - Training time on colab: ~ 30 min\n",
        "- Strong\n",
        " - Number of parameters: 47595\n",
        " - Training time on colab:  4 ~ 5 hrs\n",
        "- Boss:  \n",
        " - Number of parameters: 4364140\n",
        " - Training time on colab: 1.5~3 hrs\n",
        "\n",
        " ![Screen Shot 2021-04-29 at 16.43.57.png]()\n"
      ]
    }
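    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Two of the parameter counts above can be checked by hand. The sketch below counts weights and biases for `fcn_autoencoder`, and for `conv_autoencoder` under the assumption that the Medium setting simply drops the two layers marked `medium: remove this layer`. The helpers `linear_params` and `conv_params` are hypothetical; on a real model, `sum(p.numel() for p in model.parameters())` gives the same kind of total.\n",
        "\n",
        "```python\n",
        "def linear_params(n_in, n_out):\n",
        "    # Weight matrix plus bias vector of nn.Linear.\n",
        "    return n_in * n_out + n_out\n",
        "\n",
        "def conv_params(c_in, c_out, kernel=4):\n",
        "    # kernel x kernel weights per channel pair plus one bias per output channel;\n",
        "    # the same count holds for nn.ConvTranspose2d.\n",
        "    return c_in * c_out * kernel * kernel + c_out\n",
        "\n",
        "fcn_dims = [64 * 64 * 3, 128, 64, 12, 3]\n",
        "fcn_layers = list(zip(fcn_dims, fcn_dims[1:]))               # encoder\n",
        "fcn_layers += list(zip(fcn_dims[::-1], fcn_dims[::-1][1:]))  # mirrored decoder\n",
        "fcn_total = sum(linear_params(a, b) for a, b in fcn_layers)\n",
        "\n",
        "cnn_channels = [3, 12, 24, 48]  # conv_autoencoder without the 48-to-96 layers\n",
        "cnn_layers = list(zip(cnn_channels, cnn_channels[1:]))\n",
        "cnn_layers += list(zip(cnn_channels[::-1], cnn_channels[::-1][1:]))\n",
        "cnn_total = sum(conv_params(a, b) for a, b in cnn_layers)\n",
        "\n",
        "print(fcn_total, cnn_total)\n",
        "```\n",
        "\n",
        "These totals come out to 3176419 and 47355, which agree with the Simple and Medium figures listed above."
      ]
    }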
  ]
}